{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# [13、详细推导自动微分Forward与Reverse模式](https://www.bilibili.com/video/BV1PF411h7Ew?spm_id_from=333.788.player.switch&vd_source=cdd897fffb54b70b076681c3c4e4d45d)\n",
    "\n",
    "- [Automatic differentiation in machine learning: a survey](https://arxiv.org/abs/1502.05767)\n",
    "\n",
    "一般深度学习中会搭建一张计算图：\n",
    "![image-20250619143512986](http://assets.hypervoid.top/img/2025/06/19/image-20250619143512986-be27.png)\n",
    "\n",
    "Forward Pass后会获得一个误差函数，通过误差函数计算参数的梯度，然后更新参数的值（Backward Pass）"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 自动微分、符号微分、数值微分\n",
    "\n",
    "![image-20250619143809741](http://assets.hypervoid.top/img/2025/06/19/image-20250619143809741-959d.png)\n",
    "\n",
    "如图有一个函数表达式，**微分**就是一大串表达式。\n",
    "\n",
    "但是如何在代码中表示出来？ matlab等工具能够自动计算符号微分，但是可能很复杂，单单存储就很耗资源。\n",
    "还有办法计算数值微分，可以使用很小的步长来逼近实际的微分，但是不稳定也不准确。\n",
    "\n",
    "自动微分：以一个有向无环图的形式来一步一步地计算微分。如图中就是forward mode的微分。（这里就是计算l1对x的微分，然后以此计算l2的微分，逐步计算出l4的微分）"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 自动微分的两种模式\n",
    "\n",
    "### 前向模式\n",
    "\n",
    "$$\n",
    "y = f(x_1, x_2) = \\ln(x_1) + x_1x_2 - sin(x_2)\n",
    "$$\n",
    "这里 $x_1 = 1$ 计算 $\\frac{dy}{dx_1}$\n",
    "\n",
    "![image-20250619144919818](http://assets.hypervoid.top/img/2025/06/19/image-20250619144919818-3119.png)\n",
    "\n",
    "\n",
    "前向模式就是通过x计算y，也就是 **输入节点改变** 对 所有中间节点和输出节点的影响。 好处是不会遇到梯度累计的情况。\n",
    "\n",
    "### 后向模式\n",
    "\n",
    "![image-20250619154553911](http://assets.hypervoid.top/img/2025/06/19/image-20250619154553911-ea23.png)\n",
    "\n",
    "前向模式就是通过y计算导数，也就是**输出节点改变** 对 所有中间节点和输入节点的影响。\n",
    "必须计算出所有节点的值之后才能计算，否则会报错。\n",
    "有梯度累计的情况，可能存在链式法则导致一个节点和多个节点相关，因此需要计算并保存起来。如：\n",
    "\n",
    "```\n",
    "z1 = x^2\n",
    "z2 = 2x\n",
    "y = z1+z2 # 这里z1和z2共享x参数\n",
    "```\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 前向后向操作运算量\n",
    "\n",
    "![image-20250619160053298](http://assets.hypervoid.top/img/2025/06/19/image-20250619160053298-2576.png)\n",
    "\n",
    "现在的框架都设置为 Reverse Mode，因为模型都**假设**输入维度大于输出维度。 当然中间节点的大小可能不一样。\n",
    "\n",
    "PyTorch 自动计算梯度主要且默认采用的是 反向模式 (Backpropagation)，支持 正向模式，需要通过特定的函数 (torch.jvp 等) 来显式地使用它"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 对偶数\n",
    "\n",
    "**对偶数（Dual Numbers）** 是一种扩充实数的方式，类似于我们用 $i$ 扩充实数得到复数。一个对偶数写作：\n",
    "\n",
    "$$ z = v + \\dot{v}\\epsilon $$\n",
    "\n",
    "其中：\n",
    "*   $v$ 和 $\\dot{v}$ 都是普通的实数。$v$ 代表函数的**原始值（primal value）**。\n",
    "*   $\\epsilon$ 是一个很小的数字，其中 $\\epsilon^2 = 0$，但 $\\epsilon \\neq 0$。你可以把它理解为一个“无穷小”的量，小到其平方可以忽略不计。\n",
    "\n",
    "对偶数的主要用途是为**前向模式自动微分（Forward Mode AD）** 提供一个简洁而严谨的数学实现。它的“魔力”在于，对偶数上的基本算术运算，其结果会自动地、正确地计算出导数，完美符合微积分的求导法则。\n",
    "\n",
    "\n",
    "我们来看论文中提到的两个例子：\n",
    "\n",
    "**a) 加法法则：**\n",
    "假设有两个对偶数 $v + \\dot{v}\\epsilon$ 和 $u + \\dot{u}\\epsilon$。它们的和是：\n",
    "$$ (v + \\dot{v}\\epsilon) + (u + \\dot{u}\\epsilon) = (v+u) + (\\dot{v}+\\dot{u})\\epsilon $$\n",
    "你看，结果的原始部分是 $v+u$，导数部分是 $\\dot{v}+\\dot{u}$。这恰好对应了微积分的加法法则：$d/dx(f+g) = f^\\prime + g^\\prime$。\n",
    "\n",
    "**b) 乘法法则（这是最精妙的地方）：**\n",
    "它们的乘积是：\n",
    "$$ (v + \\dot{v}\\epsilon) \\times (u + \\dot{u}\\epsilon) = vu + v\\dot{u}\\epsilon + u\\dot{v}\\epsilon + \\dot{v}\\dot{u}\\epsilon^2 $$\n",
    "因为 $\\epsilon^2 = 0$，最后一项 $\\dot{v}\\dot{u}\\epsilon^2$ 就直接消失了！所以结果是：\n",
    "$$ vu + (v\\dot{u} + u\\dot{v})\\epsilon $$\n",
    "结果的原始部分是 $vu$，而导数部分是 $v\\dot{u} + u\\dot{v}$。这不正是微积分的**乘法法则（Product Rule）**$d/dx(f*g) = fg' + f'g$ 吗！\n",
    "\n",
    "**总结一下它的作用：**\n",
    "\n",
    "通过将程序中的所有数值替换为对偶数，并重载（redefine）所有基本运算（加、减、乘、除、sin、log等）在对偶数上的行为，我们就可以实现前向模式AD。\n",
    "\n",
    "具体流程是：\n",
    "1.  **初始化：** 想要计算函数 $f(x)$ 在某一点的导数，我们就将输入 $x$ 表示为对偶数 $x + 1\\epsilon$。（这里的$1$是因为 $dx/dx = 1$）\n",
    "2.  **计算：** 将这个对偶数代入函数 $f$ 中进行计算。由于所有运算都已为对偶数重载，计算过程中，导数部分会根据链式法则自动地、一步步地传播。\n",
    "3.  **获得结果：** 计算完成后，最终得到的对偶数会是 $f(x) + f'(x)\\epsilon$ 的形式。原始部分 $f(x)$ 就是函数值，而 $\\epsilon$ 的系数 $f'(x)$ 就是我们梦寐以求的导数值！"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 对偶数例子\n",
    "\n",
    "$$\n",
    "f(x) = x^2 + 2*x\n",
    "$$\n",
    "\n",
    "求 $f(x)$在 $x=2$处的导数(这里 $\\dot{v}=1$)：\n",
    "$$\n",
    "f(2+\\epsilon) = (2+\\epsilon)^2 + 2*(2+\\epsilon) \\\\\n",
    "= 4+4\\epsilon + \\epsilon^2 + 4 + 2\\epsilon = 8+6\\epsilon\n",
    "$$\n",
    "\n",
    "$f(2) = 8, f^\\prime(2) = 6$"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.x.x"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
